import torch
import torch.nn
from image_to_latex.lit_models import LitResNetTransformer

# Default locations: the Lightning checkpoint to read and the ONNX file to write.
CKPT_PATH = './artifacts/model.pt'
ONNX_PATH = './artifacts/model.onnx'
# Shape of the random dummy tensor used to trace the model during export.
# NOTE(review): presumably (batch, channels, height, width) — confirm against
# the model's expected input.
INPUT_SHAPE = (1, 3, 500, 200)
# ONNX opset version handed to the exporter.
OPSET_VERSION = 9


def export_to_onnx(
    ckpt_path: str = CKPT_PATH,
    onnx_path: str = ONNX_PATH,
    input_shape: tuple = INPUT_SHAPE,
    opset_version: int = OPSET_VERSION,
) -> None:
    """Load a LitResNetTransformer checkpoint and export it to ONNX.

    Args:
        ckpt_path: Checkpoint path passed to
            ``LitResNetTransformer.load_from_checkpoint``.
        onnx_path: Destination path of the exported ONNX model.
        input_shape: Shape of the random sample tensor used to trace the model.
            No ``dynamic_axes`` are passed to the exporter, so the exported
            graph is expected to accept only inputs of this exact shape.
        opset_version: ONNX opset version passed through to the exporter.
    """
    lit_model = LitResNetTransformer.load_from_checkpoint(ckpt_path)
    # Trace in inference mode so any train/eval-dependent layers (e.g. dropout,
    # batch-norm), if present, are exported with their prediction behavior.
    lit_model.eval()
    input_sample = torch.randn(input_shape)
    lit_model.to_onnx(
        onnx_path,
        input_sample,
        export_params=True,
        opset_version=opset_version,
    )


# Only export when run as a script, so importing this module has no side effects
# (no checkpoint loading, no file writes).
if __name__ == "__main__":
    export_to_onnx()
